"""Config flow for Local LLaMA Conversation integration."""
from __future__ import annotations

import os
import sys
import logging
import requests
from types import MappingProxyType
from typing import Any
from abc import ABC, abstractmethod

import voluptuous as vol

from homeassistant import config_entries
from homeassistant.core import HomeAssistant
from homeassistant.const import CONF_HOST, CONF_PORT, CONF_SSL, UnitOfTime
from homeassistant.data_entry_flow import (
    AbortFlow,
    FlowHandler,
    FlowManager,
    FlowResult,
)
from homeassistant.helpers.selector import (
    NumberSelector,
    NumberSelectorConfig,
    NumberSelectorMode,
    TemplateSelector,
    SelectSelector,
    SelectSelectorConfig,
    SelectSelectorMode,
    TextSelector,
    TextSelectorConfig,
    BooleanSelector,
    BooleanSelectorConfig,
)
from homeassistant.util.package import is_installed
from importlib.metadata import version

from .utils import download_model_from_hf, install_llama_cpp_python
from .const import (
    CONF_CHAT_MODEL,
    CONF_MAX_TOKENS,
    CONF_PROMPT,
    CONF_TEMPERATURE,
    CONF_TOP_K,
    CONF_TOP_P,
    CONF_MIN_P,
    CONF_TYPICAL_P,
    CONF_REQUEST_TIMEOUT,
    CONF_BACKEND_TYPE,
    CONF_SELECTED_LANGUAGE,
    CONF_SELECTED_LANGUAGE_OPTIONS,
    CONF_DOWNLOADED_MODEL_FILE,
    CONF_DOWNLOADED_MODEL_QUANTIZATION,
    CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS,
    CONF_PROMPT_TEMPLATE,
    CONF_ENABLE_FLASH_ATTENTION,
    CONF_USE_GBNF_GRAMMAR,
    CONF_GBNF_GRAMMAR_FILE,
    CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
    CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
    CONF_TEXT_GEN_WEBUI_PRESET,
    CONF_REFRESH_SYSTEM_PROMPT,
    CONF_REMEMBER_CONVERSATION,
    CONF_REMEMBER_NUM_INTERACTIONS,
    CONF_PROMPT_CACHING_ENABLED,
    CONF_PROMPT_CACHING_INTERVAL,
    CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    CONF_IN_CONTEXT_EXAMPLES_FILE,
    CONF_NUM_IN_CONTEXT_EXAMPLES,
    CONF_OPENAI_API_KEY,
    CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
    CONF_SERVICE_CALL_REGEX,
    CONF_REMOTE_USE_CHAT_ENDPOINT,
    CONF_TEXT_GEN_WEBUI_CHAT_MODE,
    CONF_OLLAMA_KEEP_ALIVE_MIN,
    CONF_OLLAMA_JSON_MODE,
    CONF_CONTEXT_LENGTH,
    CONF_BATCH_SIZE,
    CONF_THREAD_COUNT,
    CONF_BATCH_THREAD_COUNT,
    DEFAULT_CHAT_MODEL,
    DEFAULT_PORT,
    DEFAULT_SSL,
    DEFAULT_MAX_TOKENS,
    PERSONA_PROMPTS,
    DEFAULT_PROMPT_BASE,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    DEFAULT_MIN_P,
    DEFAULT_TYPICAL_P,
    DEFAULT_REQUEST_TIMEOUT,
    DEFAULT_BACKEND_TYPE,
    DEFAULT_DOWNLOADED_MODEL_QUANTIZATION,
    DEFAULT_PROMPT_TEMPLATE,
    DEFAULT_ENABLE_FLASH_ATTENTION,
    DEFAULT_USE_GBNF_GRAMMAR,
    DEFAULT_GBNF_GRAMMAR_FILE,
    DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
    DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
    DEFAULT_REFRESH_SYSTEM_PROMPT,
    DEFAULT_REMEMBER_CONVERSATION,
    DEFAULT_REMEMBER_NUM_INTERACTIONS,
    DEFAULT_PROMPT_CACHING_ENABLED,
    DEFAULT_PROMPT_CACHING_INTERVAL,
    DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
    DEFAULT_IN_CONTEXT_EXAMPLES_FILE,
    DEFAULT_NUM_IN_CONTEXT_EXAMPLES,
    DEFAULT_SERVICE_CALL_REGEX,
    DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
    DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
    DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
    DEFAULT_OLLAMA_JSON_MODE,
    DEFAULT_CONTEXT_LENGTH,
    DEFAULT_BATCH_SIZE,
    DEFAULT_THREAD_COUNT,
    DEFAULT_BATCH_THREAD_COUNT,
    BACKEND_TYPE_LLAMA_HF,
    BACKEND_TYPE_LLAMA_EXISTING,
    BACKEND_TYPE_TEXT_GEN_WEBUI,
    BACKEND_TYPE_GENERIC_OPENAI,
    BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
    BACKEND_TYPE_OLLAMA,
    PROMPT_TEMPLATE_DESCRIPTIONS,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
    TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
    TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
    DOMAIN,
    DEFAULT_OPTIONS,
    OPTIONS_OVERRIDES,
    RECOMMENDED_CHAT_MODELS,
    EMBEDDED_LLAMA_CPP_PYTHON_VERSION
)

_LOGGER = logging.getLogger(__name__)

def is_local_backend(backend):
    """Tell whether `backend` is one of the backend types this flow treats as local.

    :param backend: a backend type value (one of the BACKEND_TYPE_* constants)
    :return: True for BACKEND_TYPE_LLAMA_EXISTING or BACKEND_TYPE_LLAMA_HF, otherwise False
    """
    local_backends = [BACKEND_TYPE_LLAMA_EXISTING, BACKEND_TYPE_LLAMA_HF]
    return backend in local_backends

def STEP_INIT_DATA_SCHEMA(backend_type=None):
    """Build the form schema used to pick a backend.

    :param backend_type: backend to pre-select; any falsy value falls back to DEFAULT_BACKEND_TYPE
    :return: a vol.Schema with a single required CONF_BACKEND_TYPE dropdown
    """
    # The order of this list is the order in which the choices are handed to the selector.
    backend_choices = [
        BACKEND_TYPE_LLAMA_HF,
        BACKEND_TYPE_LLAMA_EXISTING,
        BACKEND_TYPE_TEXT_GEN_WEBUI,
        BACKEND_TYPE_GENERIC_OPENAI,
        BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
        BACKEND_TYPE_OLLAMA,
    ]
    preselected = backend_type or DEFAULT_BACKEND_TYPE
    backend_selector = SelectSelector(SelectSelectorConfig(
        options=backend_choices,
        translation_key=CONF_BACKEND_TYPE,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))

    return vol.Schema({
        vol.Required(CONF_BACKEND_TYPE, default=preselected): backend_selector,
    })

def STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file=None, selected_language=None):
    """Build the form schema for pointing the integration at an existing model file.

    :param model_file: value to pre-fill in the model file field; falsy values become ""
    :param selected_language: language to pre-select; falsy values become "en"
    :return: a vol.Schema with the required model file and language fields
    """
    language_selector = SelectSelector(SelectSelectorConfig(
        options=CONF_SELECTED_LANGUAGE_OPTIONS,
        translation_key=CONF_SELECTED_LANGUAGE,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))

    fields = {}
    fields[vol.Required(CONF_DOWNLOADED_MODEL_FILE, default=model_file or "")] = str
    fields[vol.Required(CONF_SELECTED_LANGUAGE, default=selected_language or "en")] = language_selector
    return vol.Schema(fields)

def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_quantization=None, selected_language=None):
    """Build the form schema for choosing a model to download.

    All arguments are keyword-only and only pre-fill the form.

    :param chat_model: model to pre-select; falsy values become DEFAULT_CHAT_MODEL
    :param downloaded_model_quantization: quantization to pre-select; falsy values become
        DEFAULT_DOWNLOADED_MODEL_QUANTIZATION
    :param selected_language: language to pre-select; falsy values become "en"
    :return: a vol.Schema with the required model, quantization and language fields
    """
    # custom_value=True lets the user type a model that is not in RECOMMENDED_CHAT_MODELS.
    model_selector = SelectSelector(SelectSelectorConfig(
        options=RECOMMENDED_CHAT_MODELS,
        custom_value=True,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))
    language_selector = SelectSelector(SelectSelectorConfig(
        options=CONF_SELECTED_LANGUAGE_OPTIONS,
        translation_key=CONF_SELECTED_LANGUAGE,
        multiple=False,
        mode=SelectSelectorMode.DROPDOWN,
    ))

    model_default = chat_model or DEFAULT_CHAT_MODEL
    quantization_default = downloaded_model_quantization or DEFAULT_DOWNLOADED_MODEL_QUANTIZATION
    language_default = selected_language or "en"

    fields = {}
    fields[vol.Required(CONF_CHAT_MODEL, default=model_default)] = model_selector
    fields[vol.Required(CONF_DOWNLOADED_MODEL_QUANTIZATION, default=quantization_default)] = vol.In(CONF_DOWNLOADED_MODEL_QUANTIZATION_OPTIONS)
    fields[vol.Required(CONF_SELECTED_LANGUAGE, default=language_default)] = language_selector
    return vol.Schema(fields)

def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None, available_chat_models=None, selected_language=None):
    """Build the form schema for connecting to a remote backend.

    :param backend_type: the selected backend; it decides the default port and whether the
        text-generation-webui admin key field is added
    :param host: value to pre-fill in the host field; falsy values become ""
    :param port: value to pre-fill in the port field; falsy values become the backend's default port
    :param ssl: value to pre-fill in the SSL toggle; only None falls back to DEFAULT_SSL, so an
        explicit False is preserved
    :param chat_model: model to pre-select; falsy values become DEFAULT_CHAT_MODEL
    :param available_chat_models: model names offered in the dropdown; None means no suggestions
    :param selected_language: language to pre-select; falsy values become "en"
    :return: a vol.Schema for the remote setup form
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    if available_chat_models is None:
        available_chat_models = []

    backend_specific_fields = {}
    default_port = DEFAULT_PORT

    if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        backend_specific_fields[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))
    elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
        default_port = "8000"
    elif backend_type == BACKEND_TYPE_OLLAMA:
        default_port = "11434"

    return vol.Schema(
        {
            vol.Required(CONF_HOST, default=host if host else ""): str,
            vol.Required(CONF_PORT, default=port if port else default_port): str,
            # `ssl` is a boolean: test against None so that a submitted False is not replaced
            # by DEFAULT_SSL when the form is shown again.
            vol.Required(CONF_SSL, default=ssl if ssl is not None else DEFAULT_SSL): bool,
            vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): SelectSelector(SelectSelectorConfig(
                options=available_chat_models,
                custom_value=True,
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,
            )),
            vol.Required(CONF_SELECTED_LANGUAGE, default=selected_language if selected_language else "en"): SelectSelector(SelectSelectorConfig(
                options=CONF_SELECTED_LANGUAGE_OPTIONS,
                translation_key=CONF_SELECTED_LANGUAGE,
                multiple=False,
                mode=SelectSelectorMode.DROPDOWN,
            )),
            vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")),
            # Backend-specific fields go last so they appear after the common ones.
            **backend_specific_fields,
        }
    )


class BaseLlamaConversationConfigFlow(FlowHandler, ABC):
    """Represent the base config flow for Local LLaMA Conversation.

    Declares the steps a concrete flow has to implement; every step takes the
    optional submitted form data and returns a FlowResult.
    """

    @property
    @abstractmethod
    def flow_manager(self) -> FlowManager:
        """Return the flow manager of the flow."""

    @abstractmethod
    async def async_step_pick_backend(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Select the backend to use """

    @abstractmethod
    async def async_step_install_local_wheels(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Install pre-built wheels """

    @abstractmethod
    async def async_step_local_model(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Configure a local model """

    @abstractmethod
    async def async_step_remote_model(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Configure a remote model """

    @abstractmethod
    async def async_step_model_parameters(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Configure the parameters (options) of the selected model """

    @abstractmethod
    async def async_step_download(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Download a model from HF """

    @abstractmethod
    async def async_step_finish(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """ Finish configuration """

class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a config flow for Local LLaMA Conversation."""

    VERSION = 1
    install_wheel_task = None
    install_wheel_error = None
    download_task = None
    download_error = None
    model_config: dict[str, Any]
    options: dict[str, Any]
    selected_language: str

    @property
    def flow_manager(self) -> config_entries.ConfigEntriesFlowManager:
        """Return the flow manager found at `hass.config_entries.flow`."""
        entries = self.hass.config_entries
        return entries.flow

    def async_remove(self) -> None:
        """Cancel the background tasks this flow may still own.

        Both the model download and the wheel installation are started by this
        flow and tracked on the instance, so both are cancelled here; a task
        that was never started (or was already reset to None) is skipped.
        """
        for pending_task in (self.download_task, self.install_wheel_task):
            if pending_task:
                pending_task.cancel()
    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Start the flow: reset the collected data and hand over to the backend picker.

        `user_input` is not used; the backend picker is always called without input.
        """
        self.model_config, self.options = {}, {}
        first_step = await self.async_step_pick_backend()
        return first_step

    async def async_step_pick_backend(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Ask for the backend type and route to the matching setup step.

        With submitted input the choice is stored in `model_config` and the flow continues:
        remote backends go to `remote_model`; local backends go to `local_model` when
        llama-cpp-python is installed at EMBEDDED_LLAMA_CPP_PYTHON_VERSION, and to
        `install_local_wheels` otherwise. Without input the form is shown, carrying a
        pending wheel-install error (which is then cleared) if there is one.
        """
        if user_input:
            # Classify before merging, exactly as a missing backend key should fail first.
            wants_local = is_local_backend(user_input[CONF_BACKEND_TYPE])
            self.model_config.update(user_input)

            if not wants_local:
                return await self.async_step_remote_model()

            wheel_ready = (
                is_installed("llama-cpp-python")
                and version("llama-cpp-python") == EMBEDDED_LLAMA_CPP_PYTHON_VERSION
            )
            if wheel_ready:
                return await self.async_step_local_model()
            return await self.async_step_install_local_wheels()

        errors = {}
        if self.install_wheel_error:
            # Report the failure once, then forget it.
            errors["base"] = str(self.install_wheel_error)
            self.install_wheel_error = None

        return self.async_show_form(
            step_id="pick_backend",
            data_schema=STEP_INIT_DATA_SCHEMA(),
            errors=errors,
            last_step=False,
        )

    async def async_step_install_local_wheels(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Run the llama-cpp-python installation as a progress step.

        The first call queues `install_llama_cpp_python` on the executor. While the task
        is unfinished the progress view is returned. Once it is done, an exception or a
        falsy result records "pip_wheel_error" and sends the flow back to `pick_backend`;
        a truthy result continues to `local_model`.
        """
        if not self.install_wheel_task:
            _LOGGER.debug("Queuing install task...")
            self.install_wheel_task = self.hass.async_add_executor_job(
                install_llama_cpp_python, self.hass.config.config_dir
            )
            still_running = True
        else:
            still_running = not self.install_wheel_task.done()

        if still_running:
            return self.async_show_progress(
                progress_task=self.install_wheel_task,
                step_id="install_local_wheels",
                progress_action="install_local_wheels",
            )

        finished_task = self.install_wheel_task
        failure = finished_task.exception()
        if failure:
            _LOGGER.warning("Failed to install wheel: %s", repr(failure))
            installed = False
        else:
            wheel_install_result = finished_task.result()
            installed = bool(wheel_install_result)
            if installed:
                _LOGGER.debug(f"Finished install: {wheel_install_result}")

        if installed:
            next_step = "local_model"
        else:
            self.install_wheel_error = "pip_wheel_error"
            next_step = "pick_backend"

        # Drop the task so that a later attempt queues a new installation.
        self.install_wheel_task = None
        return self.async_show_progress_done(next_step_id=next_step)

    async def async_step_local_model(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Collect the settings for a local model (download from HF or existing file).

        Without input the form matching the backend type is shown; a pending download
        error is reported once, with the previous choices pre-filled, and then cleared.
        With input the language is split off into `selected_language`, the rest is merged
        into `model_config`, and the flow continues to `download` (HF backend) or, when
        the given model file exists, to `model_parameters`.

        :raises ValueError: if the stored backend type is not one of the two local backends
        """
        errors = {}
        description_placeholders = {}

        backend_type = self.model_config[CONF_BACKEND_TYPE]
        if backend_type == BACKEND_TYPE_LLAMA_HF:
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA()
        elif backend_type == BACKEND_TYPE_LLAMA_EXISTING:
            schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA()
        else:
            raise ValueError(f"Backend type {backend_type!r} is not a local backend")

        if self.download_error:
            errors["base"] = "download_failed"
            description_placeholders["exception"] = str(self.download_error)
            schema = STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(
                chat_model=self.model_config[CONF_CHAT_MODEL],
                downloaded_model_quantization=self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION],
                selected_language=self.selected_language
            )
            # The error has been consumed; clear it (as pick_backend does for the wheel
            # error) so a stale failure is not reported again later.
            self.download_error = None

        # A "result" key marks input that is not a form submission, so it is ignored here.
        if user_input and "result" not in user_input:
            self.selected_language = user_input.pop(CONF_SELECTED_LANGUAGE, self.hass.config.language)

            self.model_config.update(user_input)

            if backend_type == BACKEND_TYPE_LLAMA_HF:
                return await self.async_step_download()

            model_file = self.model_config[CONF_DOWNLOADED_MODEL_FILE]
            # isfile rather than exists: a directory is not a usable model file.
            if os.path.isfile(model_file):
                self.model_config[CONF_CHAT_MODEL] = os.path.basename(model_file)
                return await self.async_step_model_parameters()

            errors["base"] = "missing_model_file"
            schema = STEP_LOCAL_SETUP_EXISTING_DATA_SCHEMA(model_file, self.selected_language)

        return self.async_show_form(
            step_id="local_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
        )

    async def async_step_download(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Run the model download as a progress step.

        The first call queues `download_model_from_hf` on the executor, targeting a
        "models" folder below the "local" media directory (or the config "media" path
        when no "local" media directory is configured). While the task is unfinished the
        progress view is returned. Once it is done, an exception is stored in
        `download_error` and the flow returns to `local_model`; otherwise the task result
        is stored as the downloaded model file and the flow continues to `model_parameters`.
        """
        if not self.download_task:
            model_name = self.model_config[CONF_CHAT_MODEL]
            quantization_type = self.model_config[CONF_DOWNLOADED_MODEL_QUANTIZATION]

            media_root = self.hass.config.media_dirs.get("local", self.hass.config.path("media"))
            storage_folder = os.path.join(media_root, "models")
            self.download_task = self.hass.async_add_executor_job(
                download_model_from_hf, model_name, quantization_type, storage_folder
            )
            still_running = True
        else:
            still_running = not self.download_task.done()

        if still_running:
            return self.async_show_progress(
                progress_task=self.download_task,
                step_id="download",
                progress_action="download",
            )

        finished_task = self.download_task
        failure = finished_task.exception()
        if failure:
            _LOGGER.info("Failed to download model: %s", repr(failure))
            self.download_error = failure
            next_step = "local_model"
        else:
            self.model_config[CONF_DOWNLOADED_MODEL_FILE] = finished_task.result()
            next_step = "model_parameters"

        # Drop the task so that a retry queues a new download.
        self.download_task = None
        return self.async_show_progress_done(next_step_id=next_step)
    
    def _validate_text_generation_webui(self, user_input: dict) -> tuple:
        """
        Checks that text-generation-webui is reachable and that it lists the configured model

        The connection details and model name are read from `self.model_config`; `user_input`
        only supplies the key sent as a bearer token (the admin key if present, else the
        OpenAI API key). The model name is compared with "/" replaced by "_".

        :param user_input: the submitted form data
        :return: a tuple of (error message name, exception detail, model names reported by the
            server); the first two are None on success and the list is empty unless the model
            is missing
        """
        try:
            headers = {}
            api_key = user_input.get(CONF_TEXT_GEN_WEBUI_ADMIN_KEY, user_input.get(CONF_OPENAI_API_KEY))
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            scheme = "https" if self.model_config[CONF_SSL] else "http"
            base_url = f"{scheme}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}"
            response = requests.get(
                f"{base_url}/v1/internal/model/list",
                timeout=5, # quick timeout
                headers=headers
            )
            response.raise_for_status()

            listing = response.json()

            model_is_listed = any(
                remote_name == self.model_config[CONF_CHAT_MODEL].replace("/", "_")
                for remote_name in listing["model_names"]
            )
            if model_is_listed:
                return None, None, []

            return "missing_model_api", None, listing["model_names"]

        except Exception as ex:
            # Any failure (network, HTTP status, unexpected payload) is reported as a connection error.
            _LOGGER.info("Connection error was: %s", repr(ex))
            return "failed_to_connect", ex, []
        
    def _validate_ollama(self, user_input: dict) -> tuple:
        """
        Checks that ollama is reachable and that it lists the configured model

        The connection details and model name are read from `self.model_config`; `user_input`
        only supplies the optional API key sent as a bearer token. The model name must match
        a reported name exactly.

        :param user_input: the submitted form data
        :return: a tuple of (error message name, exception detail, model names reported by the
            server); the first two are None on success and the list is empty unless the model
            is missing
        """
        try:
            headers = {}
            api_key = user_input.get(CONF_OPENAI_API_KEY)
            if api_key:
                headers["Authorization"] = f"Bearer {api_key}"

            scheme = "https" if self.model_config[CONF_SSL] else "http"
            base_url = f"{scheme}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}"
            response = requests.get(
                f"{base_url}/api/tags",
                timeout=5, # quick timeout
                headers=headers
            )
            response.raise_for_status()

            remote_models = response.json()["models"]

            model_is_listed = any(
                entry["name"] == self.model_config[CONF_CHAT_MODEL]
                for entry in remote_models
            )
            if model_is_listed:
                return (None, None, [])

            return "missing_model_api", None, [entry["name"] for entry in remote_models]

        except Exception as ex:
            # Any failure (network, HTTP status, unexpected payload) is reported as a connection error.
            _LOGGER.info("Connection error was: %s", repr(ex))
            return "failed_to_connect", ex, []

    async def async_step_remote_model(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Collect the connection settings for a remote backend.

        Without input the form is shown. With input the language is split off into
        `selected_language` and the rest is merged into `model_config`. For
        text-generation-webui and ollama the connection is then validated on the
        executor; a validation error re-shows the form pre-filled with the submitted
        values and the models the server reported. Any other backend, or a successful
        validation, continues to `model_parameters`. An unexpected exception is logged
        and shown as "unknown".
        """
        errors = {}
        description_placeholders = {}
        backend_type = self.model_config[CONF_BACKEND_TYPE]
        schema = STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type)

        def show_form():
            # Reads the enclosing variables at call time, so later updates are included.
            return self.async_show_form(
                step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders, last_step=False
            )

        if not user_input:
            return show_form()

        try:
            self.selected_language = user_input.pop(CONF_SELECTED_LANGUAGE, self.hass.config.language)
            self.model_config.update(user_input)

            # validate and load when using text-generation-webui or ollama
            validators = {
                BACKEND_TYPE_TEXT_GEN_WEBUI: self._validate_text_generation_webui,
                BACKEND_TYPE_OLLAMA: self._validate_ollama,
            }
            validator = validators.get(backend_type)
            if validator is None:
                error_message, ex, possible_models = None, None, []
            else:
                error_message, ex, possible_models = await self.hass.async_add_executor_job(
                    validator, user_input
                )

            if not error_message:
                return await self.async_step_model_parameters()

            errors["base"] = error_message
            if ex:
                description_placeholders["exception"] = str(ex)
            schema = STEP_REMOTE_SETUP_DATA_SCHEMA(
                backend_type,
                host=user_input[CONF_HOST],
                port=user_input[CONF_PORT],
                ssl=user_input[CONF_SSL],
                chat_model=user_input[CONF_CHAT_MODEL],
                available_chat_models=possible_models,
                selected_language=self.selected_language,
            )

        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Unexpected exception")
            errors["base"] = "unknown"

        return show_form()
    
    async def async_step_model_parameters(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Collect the model options, pre-filled with defaults suited to the chosen model.

        The defaults start from DEFAULT_OPTIONS; every OPTIONS_OVERRIDES entry whose key
        occurs in the lower-cased model name is applied on top, and the "<persona>"
        placeholder in the prompt is replaced by the persona for the selected language
        (falling back to the "en" persona). Submitted input is stored in `options` and
        run through the schema; if that succeeds the flow finishes, otherwise the error
        is logged and the form is shown again with an "unknown" error.
        """
        errors = {}
        description_placeholders = {}
        backend_type = self.model_config[CONF_BACKEND_TYPE]
        model_name = self.model_config[CONF_CHAT_MODEL].lower()

        selected_default_options = dict(DEFAULT_OPTIONS)
        for name_fragment, overrides in OPTIONS_OVERRIDES.items():
            if name_fragment in model_name:
                selected_default_options.update(overrides)

        persona = PERSONA_PROMPTS.get(self.selected_language, PERSONA_PROMPTS.get("en"))
        prompt_with_persona = selected_default_options[CONF_PROMPT].replace("<persona>", persona)
        selected_default_options[CONF_PROMPT] = prompt_with_persona

        schema = vol.Schema(local_llama_config_option_schema(selected_default_options, backend_type))

        if user_input:
            self.options = user_input
            try:
                # validate input
                schema(user_input)
                # Kept inside the try so that a failure while finishing is reported the same way.
                return await self.async_step_finish()
            except Exception:
                _LOGGER.exception("An unknown error has occurred!")
                errors["base"] = "unknown"

        return self.async_show_form(
            step_id="model_parameters",
            data_schema=schema,
            errors=errors,
            description_placeholders=description_placeholders,
        )

    async def async_step_finish(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Create the config entry from the collected model configuration and options.

        The entry title names the model (the file's base name for an existing local
        model file, the configured chat model otherwise) and whether the backend is
        "llama.cpp" or "remote".
        """
        backend = self.model_config[CONF_BACKEND_TYPE]

        if backend == BACKEND_TYPE_LLAMA_EXISTING:
            model_path = self.model_config.get(CONF_DOWNLOADED_MODEL_FILE)
            model_name = os.path.basename(model_path)
        else:
            model_name = self.model_config.get(CONF_CHAT_MODEL)

        if is_local_backend(backend):
            location = "llama.cpp"
        else:
            location = "remote"

        return self.async_create_entry(
            title=f"LLM Model '{model_name}' ({location})",
            description="A Large Language Model Chat Agent",
            data=self.model_config,
            options=self.options,
        )

    @staticmethod
    def async_get_options_flow(
        config_entry: config_entries.ConfigEntry,
    ) -> config_entries.OptionsFlow:
        """Return a new OptionsFlow handler bound to `config_entry`."""
        options_flow = OptionsFlow(config_entry)
        return options_flow


class OptionsFlow(config_entries.OptionsFlow):
    """Options flow handler for Local LLaMA Conversation."""

    def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
        """Remember the config entry whose options are being edited."""
        self.config_entry = config_entry

    async def async_step_init(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Show the options form and store the submitted options when they are valid.

        Submitted options are rejected when prompt caching is enabled without system
        prompt refresh, or when an enabled GBNF grammar / in-context examples file does
        not exist next to this module. Only one error is reported; a later check
        overwrites an earlier one.
        """
        errors = {}
        description_placeholders = {}

        if user_input is not None:
            caching_enabled = user_input.get(CONF_PROMPT_CACHING_ENABLED)
            if not user_input.get(CONF_REFRESH_SYSTEM_PROMPT) and caching_enabled:
                errors["base"] = "sys_refresh_caching_enabled"

            # (feature toggle, file option, default file, error to report), checked in this order
            required_files = (
                (CONF_USE_GBNF_GRAMMAR, CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE, "missing_gbnf_file"),
                (CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE, "missing_icl_file"),
            )
            for toggle_key, file_key, default_file, error_name in required_files:
                if not user_input.get(toggle_key):
                    continue

                filename = user_input.get(file_key, default_file)
                # File names are resolved relative to the directory of this module.
                candidate = os.path.join(os.path.dirname(__file__), filename)
                if not os.path.isfile(candidate):
                    errors["base"] = error_name
                    description_placeholders["filename"] = filename

            if not errors:
                return self.async_create_entry(title="LLaMA Conversation", data=user_input)

        schema = local_llama_config_option_schema(
            self.config_entry.options,
            self.config_entry.data[CONF_BACKEND_TYPE],
        )
        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(schema),
            errors=errors,
            description_placeholders=description_placeholders,
        )


def insert_after_key(input_dict: dict, key_name: str, other_dict: dict):
    """Return a copy of `input_dict` with `other_dict` spliced in after a given key.

    Dictionaries keep insertion order, so placing entries in the middle means
    building a new dictionary. Keys of `input_dict` are objects exposing a `schema`
    attribute; the entries of `other_dict` are inserted right after every key whose
    `schema` equals `key_name`. When no key matches, the result is a plain copy.

    :param input_dict: the dictionary to copy; it is not modified
    :param key_name: the `schema` value of the key to insert after
    :param other_dict: the entries to insert
    :return: the newly built dictionary
    """
    rebuilt = {}
    for marker, value in input_dict.items():
        rebuilt[marker] = value
        if marker.schema == key_name:
            rebuilt.update(other_dict)

    return rebuilt

def local_llama_config_option_schema(options: MappingProxyType[str, Any], backend_type: str) -> dict:
    """Return a schema for Local LLaMA completion options.

    The result is a plain dictionary mapping voluptuous markers to validators
    or selectors. A set of options shared by every backend is built first; the
    options specific to ``backend_type`` are then inserted directly after the
    max-tokens entry.

    Args:
        options: Currently stored option values. Each value is attached to its
            field as the ``suggested_value``. When empty, ``DEFAULT_OPTIONS``
            is used instead.
        backend_type: Backend identifier selecting which extra options are
            added. An unrecognised backend yields only the shared options.

    Returns:
        The ordered marker -> validator dictionary.
    """
    if not options:
        options = DEFAULT_OPTIONS

    def required(key: str, default: Any) -> vol.Required:
        # Required field carrying the stored value as suggestion plus a default.
        return vol.Required(
            key,
            description={"suggested_value": options.get(key)},
            default=default,
        )

    def optional(key: str, **marker_args: Any) -> vol.Optional:
        # Optional field carrying the stored value as suggestion.
        return vol.Optional(
            key,
            description={"suggested_value": options.get(key)},
            **marker_args,
        )

    def number(minimum: float, maximum: float, step: float = 1, **config: Any) -> NumberSelector:
        # Fresh numeric selector; extra config (unit, mode) is appended as given.
        return NumberSelector(NumberSelectorConfig(min=minimum, max=maximum, step=step, **config))

    def toggle() -> BooleanSelector:
        return BooleanSelector(BooleanSelectorConfig())

    def text_list() -> TextSelector:
        return TextSelector(TextSelectorConfig(multiple=True))

    def dropdown(choices: list, translation_key: str) -> SelectSelector:
        return SelectSelector(SelectSelectorConfig(
            options=choices,
            translation_key=translation_key,
            multiple=False,
            mode=SelectSelectorMode.DROPDOWN,
        ))

    def request_timeout() -> NumberSelector:
        return number(5, 900, 1, unit_of_measurement=UnitOfTime.SECONDS, mode=NumberSelectorMode.BOX)

    # Options shared by every backend, in display order.
    result = {
        required(CONF_PROMPT, options[CONF_PROMPT]): TemplateSelector(),
        required(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE): dropdown(
            list(PROMPT_TEMPLATE_DESCRIPTIONS.keys()), CONF_PROMPT_TEMPLATE
        ),
        required(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES): toggle(),
        required(CONF_IN_CONTEXT_EXAMPLES_FILE, DEFAULT_IN_CONTEXT_EXAMPLES_FILE): str,
        required(CONF_NUM_IN_CONTEXT_EXAMPLES, DEFAULT_NUM_IN_CONTEXT_EXAMPLES): number(1, 16, 1),
        required(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS): int,
        required(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE): text_list(),
        required(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS): text_list(),
        required(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX): str,
        required(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT): toggle(),
        required(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION): toggle(),
        optional(CONF_REMEMBER_NUM_INTERACTIONS, default=DEFAULT_REMEMBER_NUM_INTERACTIONS): int,
    }

    # Backend-specific options; stays None for an unrecognised backend.
    backend_options = None

    if is_local_backend(backend_type):
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to a single core instead of failing on ``None * 2``.
        max_threads = (os.cpu_count() or 1) * 2
        backend_options = {
            required(CONF_TOP_K, DEFAULT_TOP_K): number(1, 256, 1),
            required(CONF_TEMPERATURE, DEFAULT_TEMPERATURE): number(0, 3, 0.05),
            required(CONF_TOP_P, DEFAULT_TOP_P): number(0, 1, 0.05),
            required(CONF_MIN_P, DEFAULT_MIN_P): number(0, 1, 0.05),
            required(CONF_TYPICAL_P, DEFAULT_TYPICAL_P): number(0, 1, 0.05),
            required(CONF_PROMPT_CACHING_ENABLED, DEFAULT_PROMPT_CACHING_ENABLED): toggle(),
            required(CONF_PROMPT_CACHING_INTERVAL, DEFAULT_PROMPT_CACHING_INTERVAL): number(1, 60, 1),
            # TODO: add rope_scaling_type
            required(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH): number(512, 32768, 1),
            required(CONF_BATCH_SIZE, DEFAULT_BATCH_SIZE): number(1, 8192, 1),
            required(CONF_THREAD_COUNT, DEFAULT_THREAD_COUNT): number(1, max_threads, 1),
            required(CONF_BATCH_THREAD_COUNT, DEFAULT_BATCH_THREAD_COUNT): number(1, max_threads, 1),
            required(CONF_ENABLE_FLASH_ATTENTION, DEFAULT_ENABLE_FLASH_ATTENTION): toggle(),
            required(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR): toggle(),
            required(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE): str,
        }
    elif backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
        backend_options = {
            required(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH): number(512, 32768, 1),
            required(CONF_TOP_K, DEFAULT_TOP_K): number(1, 256, 1),
            required(CONF_TEMPERATURE, DEFAULT_TEMPERATURE): number(0, 3, 0.05),
            required(CONF_TOP_P, DEFAULT_TOP_P): number(0, 1, 0.05),
            required(CONF_MIN_P, DEFAULT_MIN_P): number(0, 1, 0.05),
            required(CONF_TYPICAL_P, DEFAULT_TYPICAL_P): number(0, 1, 0.05),
            required(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT): request_timeout(),
            required(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT): toggle(),
            # The preset has no default value.
            optional(CONF_TEXT_GEN_WEBUI_PRESET): str,
            required(CONF_TEXT_GEN_WEBUI_CHAT_MODE, DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE): dropdown(
                [TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
                CONF_TEXT_GEN_WEBUI_CHAT_MODE,
            ),
        }
    elif backend_type == BACKEND_TYPE_GENERIC_OPENAI:
        backend_options = {
            required(CONF_TEMPERATURE, DEFAULT_TEMPERATURE): number(0, 3, 0.05),
            required(CONF_TOP_P, DEFAULT_TOP_P): number(0, 1, 0.05),
            required(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT): request_timeout(),
            required(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT): toggle(),
        }
    elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
        backend_options = {
            required(CONF_TOP_K, DEFAULT_TOP_K): number(1, 256, 1),
            required(CONF_TEMPERATURE, DEFAULT_TEMPERATURE): number(0, 3, 0.05),
            required(CONF_TOP_P, DEFAULT_TOP_P): number(0, 1, 0.05),
            required(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR): toggle(),
            required(CONF_GBNF_GRAMMAR_FILE, DEFAULT_GBNF_GRAMMAR_FILE): str,
            required(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT): request_timeout(),
            required(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT): toggle(),
        }
    elif backend_type == BACKEND_TYPE_OLLAMA:
        backend_options = {
            required(CONF_CONTEXT_LENGTH, DEFAULT_CONTEXT_LENGTH): number(512, 32768, 1),
            required(CONF_TOP_K, DEFAULT_TOP_K): number(1, 256, 1),
            required(CONF_TEMPERATURE, DEFAULT_TEMPERATURE): number(0, 3, 0.05),
            required(CONF_TOP_P, DEFAULT_TOP_P): number(0, 1, 0.05),
            required(CONF_TYPICAL_P, DEFAULT_TYPICAL_P): number(0, 1, 0.05),
            required(CONF_OLLAMA_JSON_MODE, DEFAULT_OLLAMA_JSON_MODE): toggle(),
            required(CONF_REQUEST_TIMEOUT, DEFAULT_REQUEST_TIMEOUT): request_timeout(),
            required(CONF_OLLAMA_KEEP_ALIVE_MIN, DEFAULT_OLLAMA_KEEP_ALIVE_MIN): number(
                -1, 1440, 1, unit_of_measurement=UnitOfTime.MINUTES, mode=NumberSelectorMode.BOX
            ),
            required(CONF_REMOTE_USE_CHAT_ENDPOINT, DEFAULT_REMOTE_USE_CHAT_ENDPOINT): toggle(),
        }

    if backend_options is not None:
        # Backend-specific options are placed right after the max-tokens entry.
        result = insert_after_key(result, CONF_MAX_TOKENS, backend_options)

    return result
